/*
 * Copyright (c) Huawei Technologies Co., Ltd. 2021-2021. All rights reserved.
 * Description: cloud_model_adapter.cpp
 */
#include "cloud_model_adapter.h"

#include <unistd.h>

#include <memory>

// inc
#include "common/debug/log.h"
#include "infra/base/securestl.h"
#include "infra/base/process_util.h"
#include "infra/om/stats/ai_stats_log_builder.h"
#include "framework/graph/debug/ge_op_types.h"
#include "framework/common/types.h"
#include "framework/common/memory_allocator_factory.h"
#include "framework/compiler/model_compiler_factory.h"
#include "framework/compiler/compiled_model_factory.h"
#include "framework/compatible/ir_transformer.h"
#include "framework/graph/core/node/node_spec.h"

// src/framework
#include "base/common/cl_manager/ops_kernel_store_manager.h"
#include "general_compute/general_model_compiler.h"

// src/cls/dnnacl
#include "client/common/dnnacl_compiled_target.h"

using namespace hiai;
using namespace std;
using namespace ge;

#ifdef APPLIB
// Exported entry point: delegates to CloudModelAdapter::Optimize on a freshly created adapter.
// Returns FAIL if the adapter cannot be allocated, otherwise whatever Optimize returns.
int HIAI_GraphOptimizer_Optimize(ge::GraphOptimizerOptions& options, ge::ComputeGraphPtr& graphPtr, int& stage)
{
    const std::shared_ptr<CloudModelAdapter> adapter = hiai::make_shared_nothrow<CloudModelAdapter>();
    if (adapter != nullptr) {
        return adapter->Optimize(options, graphPtr, stage);
    }
    FMK_LOGE("cloud model adapter is null!");
    return FAIL;
}

// Exported entry point: delegates to CloudModelAdapter::CheckSupported on a freshly created
// adapter, which appends node names to `checkRst`.
// Returns FAIL if the adapter cannot be allocated, otherwise whatever CheckSupported returns.
int HIAI_OpsKernelInfoStore_CheckSupported(const ge::ComputeGraphPtr graphPtr, vector<string>& checkRst)
{
    const std::shared_ptr<CloudModelAdapter> adapter = hiai::make_shared_nothrow<CloudModelAdapter>();
    if (adapter != nullptr) {
        return adapter->CheckSupported(graphPtr, checkRst);
    }
    FMK_LOGE("cloud model adapter is null!");
    return FAIL;
}

// Exported entry point: delegates to CloudModelAdapter::Compile on a freshly created adapter,
// which fills `compiledTarget` on success.
// Returns FAIL if the adapter cannot be allocated, otherwise whatever Compile returns.
int HIAI_ModelBuilder_Build(
    const ge::CompileOptions& options, ge::ComputeGraphPtr graphPtr, shared_ptr<ge::CompiledTarget>& compiledTarget)
{
    const std::shared_ptr<CloudModelAdapter> adapter = hiai::make_shared_nothrow<CloudModelAdapter>();
    if (adapter != nullptr) {
        return adapter->Compile(options, graphPtr, compiledTarget);
    }
    FMK_LOGE("cloud model adapter is null!");
    return FAIL;
}

// Exported entry point: forwards `options` to hiai::Initializer::Instance()->Init().
// The result of Init() is not inspected; this function always returns 0.
int HIAI_Init(std::map<std::string, std::string>& options)
{
    const auto& initializer = hiai::Initializer::Instance();
    initializer->Init(options);
    return 0;
}

void HIAI_Stats_PluginVersion(const char* pluginVersionNum)
{
    if (pluginVersionNum != nullptr) {
        string pluginVersion = std::string(pluginVersionNum);
        constexpr int MAX_PLUGIN_VERSIONNAME_LEN = 13;
        if (pluginVersion.size() > MAX_PLUGIN_VERSIONNAME_LEN) {
            pluginVersion = "";
        }
        FMK_LOGI("GetComputeCapabilityVersion %s", pluginVersionNum);
        int32_t uid = getpid();
        std::string processName = hiai::GetProcessName();
        std::string interfaceName = std::string("CloudServiceEnable;") + std::string("pluginVersion:") + pluginVersion;

        AiStatsLogBuilder statslogbuilder;
        statslogbuilder.Stats(uid, AI_STATS_HIAI_MNGR, interfaceName.c_str(), processName.c_str(), 0);
    }
}
#endif

namespace hiai {
// Asks the ops-kernel info store of every registered logical CL to check `graphPtr` and
// appends the node names each store returns to `nodeNameVec`.
// On success the merged list (including any entries the caller passed in) is de-duplicated
// and left sorted in ascending std::string order.
// Returns FAIL when the manager is unavailable, no CL is registered, or a CL has no info
// store; in that last case names from CLs visited earlier may already have been appended.
// NOTE(review): whether the returned names denote supported or unsupported nodes is decided
// by OpsKernelInfoStore::CheckSupported — confirm against its contract.
Status CloudModelAdapter::CheckSupported(const ge::ComputeGraphPtr graphPtr, vector<string>& nodeNameVec)
{
    auto manager = OpKernelStoreManager::GetInstance();
    if (manager == nullptr) {
        FMK_LOGE("Get OpKernelStoreManager instance failed!");
        return FAIL;
    }
    const std::set<std::string> clNames = manager->GetLogicCLName();
    if (clNames.empty()) {
        FMK_LOGE("Initialize fail, no cl registered.");
        return FAIL;
    }
    for (const auto& clName : clNames) {
        shared_ptr<OpsKernelInfoStore> infoStore = manager->GetOpsKernelInfoStore(clName);
        DOMI_IF_BOOL_EXEC(infoStore == nullptr, return FAIL);
        const auto reportedNames = infoStore->CheckSupported(graphPtr);
        nodeNameVec.insert(nodeNameVec.end(), reportedNames.begin(), reportedNames.end());
    }
    if (!nodeNameVec.empty()) {
        // Routing through a std::set removes duplicates reported by several CLs and sorts them.
        const std::set<string> uniqueNames(nodeNameVec.cbegin(), nodeNameVec.cend());
        nodeNameVec.assign(uniqueNames.cbegin(), uniqueNames.cend());
        return ge::SUCCESS;
    }
    FMK_LOGI("nodeNameVec empty");
    return ge::SUCCESS;
}

// Runs every graph optimizer that each registered logical CL provides for `stage` over
// `graphPtr`. CLs are visited in std::set order; a CL with no optimizer for the stage is skipped.
// Returns FAIL when the manager is unavailable, no CL is registered, an optimizer entry is
// null, or any optimizer reports failure; ge::SUCCESS otherwise.
Status CloudModelAdapter::Optimize(ge::GraphOptimizerOptions& options, ge::ComputeGraphPtr& graphPtr, int& stage)
{
    auto manager = OpKernelStoreManager::GetInstance();
    if (manager == nullptr) {
        FMK_LOGE("Get OpKernelStoreManager instance failed!");
        return FAIL;
    }

    const std::set<std::string> clNames = manager->GetLogicCLName();
    if (clNames.empty()) {
        FMK_LOGE("Initialize fail, no cl registered.");
        return FAIL;
    }
    for (const auto& clName : clNames) {
        vector<shared_ptr<GraphOptimizer>> graphOptimizers;
        manager->GetGraphOptimizers(clName, static_cast<OptimizationStage>(stage), graphOptimizers);
        if (graphOptimizers.empty()) {
            FMK_LOGI("%s graphOptimizers empty", clName.c_str());
            continue;
        }
        FMK_LOGI("optimize clname:%s!, stage:%d", clName.c_str(), stage);
        for (const auto& optimizer : graphOptimizers) {
            // A null registration must not be dereferenced; treat it as a hard failure so the
            // graph is never reported as optimized when a pass was silently skipped.
            if (optimizer == nullptr) {
                FMK_LOGE("%s graphOptimizer is null", clName.c_str());
                return FAIL;
            }
            if (optimizer->Optimize(options, graphPtr) != ge::SUCCESS) {
                FMK_LOGE("%s graphOptimizers fail", clName.c_str());
                return FAIL;
            }
        }
    }
    return ge::SUCCESS;
}

// Finds the first NETOUTPUT node among the graph's direct nodes and records on its op desc:
//   - SetSrcName: the name of each producer node feeding one of its input data anchors;
//   - SetSrcIndex: the output-anchor index on that producer.
// Input anchors with no connected producer are skipped, so both lists hold connected inputs
// only. Does nothing when the graph is null or contains no NETOUTPUT node.
static void InsertNetOutputSrcName(ge::ComputeGraphPtr graph)
{
    if (graph == nullptr) {
        return;
    }
    for (const ge::NodePtr& n : graph->GetDirectNodes()) {
        if (n == nullptr) {
            continue;
        }
        auto nodeOpDesc = n->GetOpDesc();
        if (nodeOpDesc == nullptr || nodeOpDesc->GetType() != ge::NETOUTPUT) {
            continue;
        }
        vector<string> srcNameVec;
        vector<int64_t> srcIndexVec;
        for (const auto& inAnchor : n->GetAllInDataAnchors()) {
            if (!inAnchor) {
                continue;
            }
            // Fetch the peer once and reuse it for both the owner lookup and the index.
            auto peerOutAnchor = inAnchor->GetPeerOutAnchor();
            if (!peerOutAnchor) {
                continue;
            }
            auto peerSrcNode = peerOutAnchor->GetOwnerNode();
            if (!peerSrcNode) {
                continue;
            }
            srcNameVec.emplace_back(peerSrcNode->ROLE(NodeSpec).Name());
            srcIndexVec.emplace_back(peerOutAnchor->GetIdx());
        }
        nodeOpDesc->SetSrcName(srcNameVec);
        nodeOpDesc->SetSrcIndex(srcIndexVec);
        return;
    }
}

// Compiles `graphPtr` as an HCS partition model and returns the serialized result through
// `compiledTarget`, wrapped in a dnnacl::DnnaclCompiledTarget.
// Steps:
//   1. annotate the NETOUTPUT node with its producers;
//   2. transform the graph to the standard IR;
//   3. compile it;
//   4. save the compiled model to a buffer;
//   5. hand that buffer to the target.
// Returns FAIL if any step fails (compiledTarget is then left untouched); ge::SUCCESS otherwise.
Status CloudModelAdapter::Compile(
    const ge::CompileOptions& options, ge::ComputeGraphPtr graphPtr, shared_ptr<ge::CompiledTarget>& compiledTarget)
{
    // Reject a null graph up front instead of handing it to the helpers below.
    if (graphPtr == nullptr) {
        FMK_LOGE("graphPtr is null");
        return FAIL;
    }
    if (MemoryAllocatorFactory::Instance() == nullptr) {
        FMK_LOGE("MemoryAllocatorFactory::Instance() is null");
        return FAIL;
    }

    InsertNetOutputSrcName(graphPtr);
    hiai::ModelCompileOptions compileOption;
    compileOption.memAllocator = MemoryAllocatorFactory::Instance()->CreateAllocator(hiai::MemoryType::DDR);
    compileOption.isNeedKernelBin = true;
    compileOption.onlyInferShape = true;
    compileOption.tuningStrategy = options.tuningStrategy;

    if (!hiai::IRTransformer::TransferToStandard(graphPtr)) {
        FMK_LOGE("Transfer graph to standard IR failed!");
        return FAIL;
    }

    ModelType type = HCS_PARTITION_MODEL;
    std::shared_ptr<hiai::ICompiledModel> iCompiledModel = hiai::CompiledModelFactory::GetInstance().Create(type);
    if (iCompiledModel == nullptr) {
        FMK_LOGE("Create Compilde Model failed!");
        return FAIL;
    }
    std::shared_ptr<hiai::ModelCompiler> iCompiler = hiai::ModelCompilerFactory::GetInstance().Create(type);
    if (iCompiler == nullptr) {
        FMK_LOGE("Create Model Compiler failed!");
        return FAIL;
    }

    Status ret = iCompiler->Compile(compileOption, graphPtr, iCompiledModel);
    if (ret != ge::SUCCESS) {
        FMK_LOGE("Generate om graph failed, error code: %d", ret);
        return FAIL;
    }

    ge::BaseBuffer modelBuffer;
    ret = iCompiledModel->SaveToBuffer(modelBuffer);
    // Own the serialized bytes so that every early return below releases them with delete[].
    // NOTE(review): assumes SaveToBuffer allocates the storage with new[] and that BaseBuffer
    // does not free it itself — confirm against BaseBuffer / SaveToBuffer.
    std::unique_ptr<uint8_t[]> modelData(modelBuffer.GetData());
    if (ret != ge::SUCCESS) {
        FMK_LOGE("Save compiled model failed, error code: %d", ret);
        return FAIL;
    }

    shared_ptr<dnnacl::DnnaclCompiledTarget> dnnaclCompiledTarget =
        hiai::make_shared_nothrow<dnnacl::DnnaclCompiledTarget>();
    if (dnnaclCompiledTarget == nullptr) {
        FMK_LOGE("create dnnaclCompiledTarget fail");
        return FAIL;
    }
    if (dnnaclCompiledTarget->SetData(modelData.get(), modelBuffer.GetSize(), false) != ge::SUCCESS) {
        FMK_LOGE("dnnaclCompiledTarget set data fail");
        return FAIL;
    }
    // Once SetData succeeds the bytes are no longer freed here; the target is presumed to own
    // them from now on. TODO(review): confirm SetData(..., false) transfers ownership.
    static_cast<void>(modelData.release());
    compiledTarget = static_pointer_cast<ge::CompiledTarget>(dnnaclCompiledTarget);
    return ge::SUCCESS;
}
} // namespace hiai
